-
Notifications
You must be signed in to change notification settings - Fork 68
Add CausalMask support with new flash attention api #604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.
Key Changes:
- Added
SubgroupLayoutQKtemplate parameter to the collective mainloop and kernel interfaces - Implemented causal masking logic that applies
-INFINITYto attention scores beyond the causal boundary - Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp |
Implements causal mask logic and removes the static assertion that previously blocked causal mask usage |
applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp |
Adds subgroup layout type alias and computes sequence coordinates for causal masking |
examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp |
Adds SubgroupLayoutQK template parameter to mainloop type |
examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp |
Conditionally selects causal or non-causal kernel based on is_causal option |
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
bb07ccc to
836f2c4
Compare
Signed-off-by: Chen, Xi2 <[email protected]>
836f2c4 to
21a1bce
Compare
| } | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ClarkChin08 Thanks for updating this! By the way, we can make the code even cleaner by including the block offset in cS_thread itself. Something like this should do it:
Tensor gP = local_tile(cP, TileShapeQK{}, blk_qv);
auto cS_thread = thr_mma_qk.partition_C(gP);
Then you don't need to do the blocking calculations here; instead row_idx = get<0>(cS_thread(i)), col_idx = get<1>(cS_thread(i)).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @petercad , I changed to use local_tile to get global col and row indices.
include/cute/atom/mma_atom.hpp
Outdated
|
|
||
| CUTE_HOST_DEVICE constexpr auto | ||
| get_atom_layout_mnk() const { | ||
| return atom_layout_mnk_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of adding a new atom_layout_mnk_ member:
| return atom_layout_mnk_; | |
| return AtomLayoutMNK{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| auto discard_seq_coord = s.seq_len_qo - offset; | ||
| auto full_tile_offset = s.seq_len_kv - offset; | ||
|
|
||
| int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sub_group_id / get<1>(shape(SubgroupLayoutQK{})) part is making a strong assumption about how subgroup tiles are arranged within the workgroup tile (K-major). We need to either add a static_assert for this condition, or (better) use CuTe layout algebra to calculate the subgroup Q offset. For instance:
auto cS = make_identity_tensor(take<0,2>(TiledMMAKQ{}.tile_size()));
auto tScS = TiledMMAKQ{}.get_slice(thread_idx).partition_C(cS);
auto q_offset_wi = get<0>(tScS(0)); /* Q offset for thread */
auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); /* Q offset for SG */There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, @petercad. I'm now implementing the algebraic approach you suggested for calculating q_offset_sg.
Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
No description provided.